%  
%  This program replicates the simulation results reported in 
%  "Testing Monotonicity of Mean Potential Outcomes in a Continuous Treatment 
%   with High-Dimensional Data" 
%

% ---- Session setup ----
clear;                        % remove every variable from the workspace
tic;                          % start the timer that toc reads after each design
warning('off','all');         % silence all warnings

% ---- Random number generator ----
% Seed a Mersenne Twister stream and install it as the global stream.
rseed=RandStream('mt19937ar','Seed',10220408);
RandStream.setGlobalStream(rseed);

% ---- Simulation design ----
S=1000;                       % number of simulations
B=1000;                       % number of bootstrap replications
p=100;                        % number of regressors
beta=1./(1:p)';               % p-by-1 coefficient vector, beta(j)=1/j
%beta=(1./((1:1:p).^2))';     % alternative: beta(j)=1/j^2
alpha=0.1;                    % level; critical values use the (1-alpha+eta) quantile
eta=10^-6;                    % small constant added to the quantile level and to the critical value
epsi=10^-6;                   % multiplies the reference sigma to form a lower bound on sigma
K_iter=5;                     % iteration count handed to func_LogitCross and func_SUZ
trim=0.025;                   % lower bound on the estimated conditional density (eps or 0.025)
TXs=0;                        % regressor set for gamma: 0 -> [T,X], 1 -> polynomial terms
t_grid=0:0.01:1;              % grid over T handed to func_SUZ

% Main experiment: 24 designs (4 sample sizes x 6 DGPs). For each design the
% script runs S Monte Carlo replications, prints the rejection frequencies of
% the CvM test and of the SUZ procedure, and saves the workspace to a .mat file.
n_grid=[200,400,800,1200];        % sample sizes; six consecutive jj share one entry
Sigma=toeplitz((.5).^(0:p-1));    % regressor covariance, Sigma(i,j)=0.5^|i-j|; depends only on p

for jj=1:24
    % Design for this jj (same values as the former one-line-per-jj if chain)
    K=5;                          % number of cross-fitting folds
    n=n_grid(ceil(jj/6));         % jj 1-6 -> 200, 7-12 -> 400, 13-18 -> 800, 19-24 -> 1200
    q1=n/50;                      % largest number of cells on [0,1]: 4, 8, 16, 24
    D=mod(jj-1,6)+1;              % DGP index, cycles 1..6 within each sample size

    % nuisance parameter
    sn=sqrt(n);
    an=0.15*log(n);
    bn=0.85*(log(n)/(log(log(n))));
    %an=sqrt(0.3*log(n));
    %bn=sqrt(0.4*(log(n)/(log(log(n)))));
    M=round(n^(2/3));             % grid points per cell in the numerical integration

    % Weight function Q
    % ln counts all pairs (i<j) of cells over q=2..q1; each row of ind_ln is
    % [q, number of pairs for that q, j, i].
    ln=0; for q=2:q1; ln=ln+q*(q-1)/2; end
    ind_ln=zeros(ln,4);
    idx=1;
    for q=2:q1
        for i=1:q
            for j=i+1:q
                ind=[q,q*(q-1)/2,j,i];
                ind_ln(idx,:)=ind;
                idx=idx+1;
            end
        end
    end
    ind_q=(ind_ln(:,1).^(-2))./ind_ln(:,2);   % weight q^(-2)/(number of pairs for q)

    Tn_KS=zeros(S,1);
    cv_KS=zeros(S,1);
    Tn_CvM=zeros(S,1);
    cv_CvM=zeros(S,1);
    cv_SUZ=zeros(S,1);

    parfor s=1:S
        if mod(s,10)==0; disp([jj,s]); end

        % Give every (jj,s) replication its own substream of one seeded
        % generator and make it the global stream where this iteration runs.
        % A stream installed on the client before the loop is not the one
        % parallel workers draw from, so without this the draws depend on how
        % iterations are scheduled across workers. 'mrg32k3a' is used because
        % it supports substreams.
        stream=RandStream('mrg32k3a','Seed',10220408);
        stream.Substream=(jj-1)*S+s;
        RandStream.setGlobalStream(stream);

        % DGP
        Uy=randn(n,1);
        Ut=randn(n,1);
        X=mvnrnd(zeros(1,p),Sigma,n);
        Xb=X*beta;
        T=(Xb+3.6)/7.2+0.5*Ut;
        Y=Uy;                     % D==1: outcome is pure noise
        switch D
            case 2; Y=Xb.*T+T.^2+Xb+Uy;
            case 3; Y=Xb.*T-T+Xb+Uy;
            case 4; Y=Xb.*T+sin(pi*T)+Xb+Uy;
            case 5; Y=Xb.*T-T/sqrt(n)+Xb+Uy;
            case 6; Y=Xb.*T+(sin(pi*T))./sqrt(n)+Xb+Uy;
        end
        if TXs==0; TX=[T,X]; end
        if TXs==1; TX=[T,T.^2,T.^3,(T*ones(1,p)).*X,X,X.^2,X.^3]; end

        % Cross-fitting
        % Observations are assigned to K consecutive folds of nK rows each.
        nK=floor(n/K);
        idx_Kn=kron((1:K)',ones(nK,1));
        phat=zeros(n,1);
        gamma=zeros(n,1);
        B0_Ks=zeros(K,1);
        B1_Ks=zeros(size(TX,2),K);

        for k=1:K
            idx_k=(idx_Kn~=k);    % training rows: all folds except k
            Y_k=Y(idx_k);
            T_k=T(idx_k);
            X_k=X(idx_k,:);
            n1=length(Y_k);
            idx_k0=(idx_Kn==k);   % held-out rows: fold k
            T_k0=T(idx_k0);
            X_k0=X(idx_k0,:);

            % Conditional density estimator p(T,X)
            % Central difference of two fitted conditional probabilities
            % P(T<=t+h|X) and P(T<=t-h|X), bounded below by trim.
            h_1=1.06*std(T_k)*n1^(-1/5);
            lambda_global=1.1*norminv(1-(1/log(n1))/max(n1*h_1,p))*sqrt(n1);
            phat0=zeros(nK,nK);
            for i=1:nK
                temp1=func_LogitCross(T_k<=(T_k0(i)+h_1),X_k,X_k0,n1,p,lambda_global,K_iter);
                temp2=func_LogitCross(T_k<=(T_k0(i)-h_1),X_k,X_k0,n1,p,lambda_global,K_iter);
                phat0(:,i)=max(trim,abs(temp1-temp2)/(2*h_1));
            end
            phat(idx_k0)=diag(phat0);   % keep entry (i,i): column i is built from T_k0(i)

            % Estimate gamma(t,x) by LASSO
            % Penalty chosen by 10-fold CV (minimum MSE); coefficients are
            % stored per fold for the numerical integration below.
            TX_k=TX(idx_k,:);
            TX_k0=TX(idx_k0,:);
            [B_lasso,FitInfo]=lasso(TX_k,Y_k,'CV',10);
            idx_cv=FitInfo.IndexMinMSE;
            B0_cv=FitInfo.Intercept(idx_cv);
            B1_cv=B_lasso(:,idx_cv);
            gamma(idx_k0)=B0_cv+TX_k0*B1_cv;
            B0_Ks(k)=B0_cv;
            B1_Ks(:,k)=B1_cv;
        end

        % Test statistics
        idx=1;
        zero_ln=zeros(ln,1);
        Tn1_ln=zero_ln;
        Tn2_ln=zero_ln;
        sig_ln=zero_ln;
        psi_ln=zero_ln;
        phi_ln=zeros(n,ln);

        for q=2:q1
            % Instrumental functions
            % ind_T(:,c) flags observations with T in the c-th of q equal cells of [0,1].
            r=1/q;
            qi=(0:r:1);
            qn=length(qi);
            ind_T1=(T*ones(1,q)<=(ones(n,1)*qi(2:qn)));
            ind_T2=(T*ones(1,q)>=(ones(n,1)*qi(1:q)));
            ind_T=ind_T1&ind_T2;

            % numerical integration of gamma(t,x)
            % t_list has M grid points per cell (cell end points are shared).
            % NOTE(review): dividing by M makes gamma_int the grid average of
            % gamma over each cell, i.e. roughly q times a Riemann sum of the
            % integral over a cell of width 1/q - confirm this scaling is
            % intended relative to phi0.
            t_list=(0:1/(q*(M-1)):1)';
            tn=length(t_list);
            ind_t0=(M-1)*(0:q)+1;
            ind_t1=(t_list*ones(1,q)<=(ones(tn,1)*t_list(ind_t0(2:qn))'));
            ind_t2=(t_list*ones(1,q)>=(ones(tn,1)*t_list(ind_t0(1:q))'));
            ind_t=(ind_t1.*ind_t2)/M;
            gamma_int0=zeros(n,tn);
            for t=1:tn
                T_t=t_list(t)*ones(n,1);
                if TXs==0; TX_t=[T_t,X]; end
                if TXs==1; TX_t=[T_t,T_t.^2,T_t.^3,(T_t*ones(1,p)).*X,X,X.^2,X.^3]; end
                for k=1:K
                    % evaluate each fold's fit on its own held-out rows
                    idx_k0=(idx_Kn==k);
                    TX_tk0=TX_t(idx_k0,:);
                    gamma_int0(idx_k0,t)=B0_Ks(k)+TX_tk0*B1_Ks(:,k);
                end
            end
            gamma_int=gamma_int0*ind_t;

            % Estimation of v(l) and sigma
            phi0=(((Y-gamma)./phat)*ones(1,q)).*ind_T;
            vi0=phi0+gamma_int;
            vn0=mean(vi0);
            for i=1:q
                vn2=vn0(i);
                phi2=phi0(:,i)+gamma_int(:,i)-vn2;
                for j=i+1:q
                    vn1=vn0(j);
                    phi1=phi0(:,j)+gamma_int(:,j)-vn1;
                    vn=vn2-vn1;           % cell i minus cell j, with i<j
                    phi=phi2-phi1;
                    sig2=mean(phi.^2);
                    if q==2 && i==1 && j==2
                        % the first pair sets the reference scale sig0
                        sig0=sqrt(sig2);
                        sig=sig0;
                    else
                        % later pairs are bounded below by epsi*sig0
                        sig=max([sqrt(sig2),epsi*sig0]);
                    end
                    Tn1=sn*vn/sig;
                    Tn2=max([Tn1,0])^2;
                    psi=-bn*(Tn1<-an);    % shift of -bn when Tn1 is below -an
                    Tn1_ln(idx,1)=Tn1;
                    Tn2_ln(idx,1)=Tn2;
                    sig_ln(idx,1)=sig;
                    psi_ln(idx,1)=psi;
                    phi_ln(:,idx)=phi;
                    idx=idx+1;
                end
            end
        end
        Tn_KS(s,1)=max(Tn1_ln);
        Tn_CvM(s,1)=Tn2_ln'*ind_q;

        % Multiplier null distribution
        cvB_KS=zeros(B,1);
        cvB_CvM=zeros(B,1);
        for b=1:B
            %U=randn(n,1);
            U=2*sqrt(3)*rand(n,1)-sqrt(3);   % uniform on [-sqrt(3),sqrt(3)]: mean 0, variance 1
            phi_u=(1/sn)*phi_ln'*U;
            Tn1_u=phi_u./sig_ln+psi_ln;
            Tn2_u=max(Tn1_u,0).^2;
            cvB_KS(b)=max(Tn1_u);
            cvB_CvM(b)=Tn2_u'*ind_q;
        end
        % quantile does not need sorted input, so the draws are used as they are
        cv_KS(s,1)=quantile(cvB_KS,1-alpha+eta)+eta;
        cv_CvM(s,1)=quantile(cvB_CvM,1-alpha+eta)+eta;

        % SUZ method
        % Only the second output is used; all four are still requested so
        % func_SUZ sees the same number of outputs as before.
        [~,cb1_R,~,~] = func_SUZ(Y,X,T,B,alpha,K_iter,t_grid);
        cv_SUZ(s)=(min(cb1_R)<0)+0;       % 1 if the smallest cb1_R value is negative

    end % S

    %disp([mean(Tn_KS>cv_KS)',mean(Tn_CvM>cv_CvM)',mean(cv_SUZ)]);
    disp([mean(Tn_CvM>cv_CvM)',mean(cv_SUZ)]);

    % elapsed time since the tic at the top of the script (cumulative over jj)
    time=toc; disp(['time:    ',num2str(time)]);

    % Save the whole workspace; names match the former per-jj list,
    % e.g. sim0520_n200_b1D1 ... sim0520_n1200_b1D6.
    save(sprintf('sim0520_n%d_b1D%d',n,D));

end  % jj %



